#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections import OrderedDict

from polygraphy import util
from polygraphy.comparator import RunResults
from polygraphy.json import load_json
from polygraphy.logger import G_LOGGER
from polygraphy.tools.base import Tool


class ToInput(Tool):
    """
    Combines and converts one or more input/output files generated by
    Polygraphy into a single file usable with --load-inputs.
    """

    def __init__(self):
        super().__init__("to-input")

    def add_parser_args(self, parser):
        """
        Register the command-line arguments for this tool.

        Args:
            parser: The argument parser to add arguments to.
        """
        parser.add_argument(
            "paths", help="Path(s) to file(s) containing input or output data from Polygraphy", nargs="+"
        )
        parser.add_argument("-o", "--output", help="Path to the file to generate", required=True)

    def run(self, args):
        """
        Merge every file in ``args.paths`` into one list of per-iteration
        dictionaries and save it as JSON to ``args.output``.

        Iteration ``i`` of every file is merged into iteration ``i`` of the result.
        When several files (or several entries of a ``RunResults``) provide the
        same key for the same iteration, the one processed last wins, because
        merging uses ``dict.update``.

        Args:
            args: Parsed arguments; reads ``args.paths`` and ``args.output``.
        """
        # One OrderedDict per iteration; grows to the longest iteration count seen.
        inputs = []

        def update_inputs(new_inputs, path):
            """
            Merge the iterations in ``new_inputs`` (loaded from ``path``) into ``inputs``
            position by position, padding ``inputs`` first if ``new_inputs`` is longer.
            """
            if inputs and len(inputs) != len(new_inputs):
                G_LOGGER.warning(
                    "The provided files have different numbers of iterations.\n"
                    "Note: Inputs currently contains {:} iterations, but the data in {:} contains {:} iterations. "
                    "Some iterations will contain incomplete data".format(len(inputs), path, len(new_inputs))
                )

            # Pad to appropriate length. Each new iteration needs its *own* dict:
            # `[OrderedDict()] * n` would repeat a reference to a single dict n times,
            # so every iteration would end up sharing (and overwriting) the same data.
            # A negative count yields an empty range, so nothing is added in that case.
            inputs.extend(OrderedDict() for _ in range(len(new_inputs) - len(inputs)))

            # zip stops at the shorter sequence, so iterations missing from
            # `new_inputs` are left untouched.
            for inp, new_inp in zip(inputs, new_inputs):
                inp.update(new_inp)

        for path in args.paths:
            # Note: It's important we have encode/decode JSON methods registered
            # for the types we care about, e.g. RunResults. Importing the class should generally guarantee this.
            data = load_json(path)
            if isinstance(data, RunResults):
                # Each RunResults entry carries its own list of iterations (presumably
                # one entry per runner); merge each list in turn.
                for _, iters in data.items():
                    update_inputs(iters, path)
            else:
                # A single feed dict is treated as a one-iteration list.
                if not util.is_sequence(data):
                    data = [data]
                update_inputs(data, path)

        util.save_json(inputs, args.output, description="input file containing {:} iteration(s)".format(len(inputs)))
